{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import math,time\n",
    "from datetime import timedelta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#权重\n",
    "def new_W(shape):\n",
    "    W=tf.Variable(tf.truncated_normal(shape,stddev=0.05))\n",
    "    return W\n",
    "\n",
    "#偏置\n",
    "def new_b(length):\n",
    "    b=tf.Variable(tf.constant(0.1,shape=length))\n",
    "    return b\n",
    "\n",
    "#卷积\n",
    "def conv2d(x,W):\n",
    "    conv=tf.nn.conv2d(x,W,strides=[1,1,1,1],padding='SAME')\n",
    "    return conv\n",
    "\n",
    "#最大池化\n",
    "def max_pool_2x2(inputx):\n",
    "    max_pool=tf.nn.max_pool(inputx,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')\n",
    "    return max_pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting dataset/train-images-idx3-ubyte.gz\n",
      "Extracting dataset/train-labels-idx1-ubyte.gz\n",
      "Extracting dataset/t10k-images-idx3-ubyte.gz\n",
      "Extracting dataset/t10k-labels-idx1-ubyte.gz\n",
      "\n",
      "train set:label_num=55000,img_num=55000\n",
      "test set:label_num=10000,img_num=55000\n",
      "validation set:label_num=5000,img_num=5000\n"
     ]
    }
   ],
   "source": [
    "#加载数据集\n",
    "data=input_data.read_data_sets('dataset/',one_hot=True)\n",
    "\n",
    "#打印一些数据集信息\n",
    "print('\\n' 'train set:label_num={},img_num={}'.format(len(data.train.labels),data.train.num_examples))\n",
    "print('test set:label_num={},img_num={}'.format(len(data.test.labels),data.train.num_examples))\n",
    "print('validation set:label_num={},img_num={}'.format(len(data.validation.labels),data.validation.num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.test.cls=np.argmax(data.test.labels,axis=1)\n",
    "#输入值\n",
    "x=tf.placeholder('float',shape=[None,784],name='x')\n",
    "x_image=tf.reshape(x,[-1,28,28,1])\n",
    "\n",
    "#输出值\n",
    "y_true=tf.placeholder('float',shape=[None,10],name='y_true')\n",
    "y_true_cls=tf.argmax(y_true,dimension=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "#conv1 & relu & max_pool\n",
    "layer_conv1={'weights':new_W([5,5,1,32]),'biases':new_b([32])}\n",
    "h_conv1=tf.nn.relu(conv2d(x_image,layer_conv1['weights'])+layer_conv1['biases'])\n",
    "h_pool1=max_pool_2x2(h_conv1)\n",
    "\n",
    "#conv2 & relu & max_pool\n",
    "layer_conv2={'weights':new_W([5,5,32,64]),'biases':new_b([64])}\n",
    "h_conv2=tf.nn.relu(conv2d(h_pool1,layer_conv2['weights'])+layer_conv2['biases'])\n",
    "h_pool2=max_pool_2x2(h_conv2)\n",
    "\n",
    "#fc1-->flat-->relu\n",
    "fc1_layer={'weights':new_W([7*7*64,1024]),'biases':new_b([1024])}\n",
    "h_pool2_plat=tf.reshape(h_pool2,[-1,7*7*64])\n",
    "h_fc1=tf.nn.relu(tf.matmul(h_pool2_plat,fc1_layer['weights'])+fc1_layer['biases'])\n",
    "\n",
    "#Dropout\n",
    "keep_prob=tf.placeholder('float')\n",
    "h_fc1_drop=tf.nn.dropout(h_fc1,keep_prob)\n",
    "\n",
    "#fc2\n",
    "fc2_layer={'weights':new_W([1024,10]),'biases':new_b([10])}\n",
    "\n",
    "#使用sotfmax()做预测\n",
    "y_pred=tf.nn.softmax(tf.matmul(h_fc1_drop,fc2_layer['weights'])+fc2_layer['biases'])\n",
    "y_pred_cls=tf.argmax(y_pred,dimension=1) #显示真实预测数字\n",
    "\n",
    "#使用RMSPropOptimizer()优化交叉熵损失函数\n",
    "cross_entropy=-tf.reduce_mean(y_true * tf.log(y_pred))\n",
    "# cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred))\n",
    "optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(cross_entropy)\n",
    "\n",
    "#计算准确率\n",
    "correct_prediction=tf.equal(y_pred_cls,y_true_cls)\n",
    "accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter:     0,train_acc: 5.00%\n",
      "Iter:   500,train_acc:100.00%\n",
      "Iter:  1000,train_acc:100.00%\n",
      "Iter:  1500,train_acc:100.00%\n",
      "Iter:  2000,train_acc:100.00%\n",
      "Iter:  2500,train_acc:100.00%\n",
      "Iter:  3000,train_acc:100.00%\n",
      "Iter:  3500,train_acc:100.00%\n",
      "Iter:  4000,train_acc:100.00%\n",
      "time:0:00:17\n",
      "test_acc:99.34% (9934/10000)\n"
     ]
    }
   ],
   "source": [
    "# 训练 & 测试\n",
    "with tf.Session() as sess:\n",
    "    init=tf.global_variables_initializer()\n",
    "    sess.run(init)\n",
    "    \n",
    "    #训练\n",
    "    train_batch_size=100\n",
    "    display_num=500\n",
    "    def train(train_epoch):\n",
    "        total_iter=0\n",
    "        start_time=time.time()\n",
    "        \n",
    "        for i in range(total_iter,total_iter+train_epoch):\n",
    "            x_batch,y_true_batch=data.train.next_batch(train_batch_size)\n",
    "            feed_dict_train_op={x:x_batch,y_true:y_true_batch,keep_prob:0.8}\n",
    "            feed_dict_train={x:x_batch,y_true:y_true_batch,keep_prob:1.0}\n",
    "            \n",
    "            sess.run(optimizer,feed_dict=feed_dict_train_op)\n",
    "            \n",
    "            #训练状态\n",
    "            if (i % display_num==0):\n",
    "                #计算训练集的准确率\n",
    "                acc=sess.run(accuracy,feed_dict=feed_dict_train)\n",
    "                #打印消息\n",
    "                msg='Iter:{0:>6},train_acc:{1:6.2%}'\n",
    "                print(msg.format(i,acc))\n",
    "        \n",
    "#         total_iter+=train_epoch\n",
    "        \n",
    "        end_time=time.time()\n",
    "        use_time=end_time-start_time\n",
    "        print('time:'+str(timedelta(seconds=int(round(use_time)))))\n",
    "    \n",
    "    #测试\n",
    "    test_batch_size=100\n",
    "    def test():\n",
    "        num_test=len(data.test.images)\n",
    "        cls_pred=np.zeros(shape=num_test,dtype=np.int)\n",
    "        \n",
    "        i=0\n",
    "        while i<num_test:\n",
    "            j=min(i+test_batch_size,num_test)\n",
    "            img=data.test.images[i:j,:]\n",
    "            labels=data.test.labels[i:j,:]\n",
    "            feed_dict={x:img,y_true:labels,keep_prob:1.0}\n",
    "            cls_pred[i:j]=sess.run(y_pred_cls,feed_dict=feed_dict)\n",
    "            i=j\n",
    "            \n",
    "        cls_true=data.test.cls\n",
    "        correct=(cls_true==cls_pred)\n",
    "        correct_sum=correct.sum()\n",
    "        acc=float(correct_sum)/num_test\n",
    "        \n",
    "        msg='test_acc:{0:.2%} ({1}/{2})'\n",
    "        print(msg.format(acc,correct_sum,num_test))\n",
    "\n",
    "    \n",
    "    train_epoch=4001\n",
    "    train(train_epoch)\n",
    "    test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n激活函数：Leaky-ReLU\\n卷积核3*3、1*1\\n优化器\\n学习率、dropout值'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#更新\n",
    "'''\n",
    "激活函数：ReLU，Leaky-ReLU\n",
    "卷积核3*3、1*1\n",
    "优化器：GradientDescentOptimizer(0.01)，AdadeltaOptimizer(0.01)，AdamOptimizer()，RMSPropOptimizer(learning_rate=0.001)，AdagradOptimizer(learning_rate=0.01)等\n",
    "学习率、dropout值\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
